#collapse
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.integrate import solve_ivp
from ipywidgets import interactive
# Rest configuration: N point masses evenly spaced on [0, 1], all at height 0.
N = 3
xs0 = np.linspace(0, 1, N)
ys0 = np.zeros(N)
k = 1.0  # spring coefficient
m = 1.0  # masses

# Draw the rest configuration.
fig, ax = plt.subplots()
ax.plot(xs0, ys0, 'o-');

# Rest offsets (dx, dy) of each spring between neighbouring points, and the
# corresponding rest lengths.
dxs0 = np.diff(xs0)
dys0 = np.diff(ys0)
ls0 = np.sqrt(np.square(dxs0) + np.square(dys0))
def calculate_spring_forces(xs, ys):
    """Return the per-spring force components ``(fxs, fys)``.

    Spring i joins points i and i+1. Its force components are ``k`` times the
    change of its (dx, dy) offset relative to the module-level rest offsets
    ``dxs0`` / ``dys0``. A positive value means the offset grew compared to rest.

    NOTE(review): the model is linear in each coordinate separately, not
    Hookean along the spring axis. The rest lengths ``ls0`` computed at module
    level are not used here.
    """
    stretch_x = np.diff(xs) - dxs0
    stretch_y = np.diff(ys) - dys0
    return k * stretch_x, k * stretch_y
# Sanity check: at the rest configuration every spring force is zero.
# (The notebook output is kept as a comment; as bare code it raised a
# NameError because `array` is not defined.)
calculate_spring_forces(xs=xs0, ys=ys0)
# -> (array([0., 0.]), array([0., 0.]))
def f(x0=0.0, y0=0.0, x1=0.5, y1=0.0, x2=1.0, y2=0.0):
    """Draw the points (x0, y0), (x1, y1), (x2, y2) joined by line segments.

    Callback for the first interactive widget. The axis limits are fixed, so
    the view does not rescale as the sliders move.
    """
    _, axes = plt.subplots()
    axes.plot([x0, x1, x2], [y0, y1, y2], 'o-')
    axes.set_xlim(-0.5, 1.5)
    axes.set_ylim(-1, 1)
    plt.show()

    
# One slider per coordinate; each change re-runs f with the slider values.
interactive_plot = interactive(
    f,
    x0=(0.0, 1.0),
    x1=(0.0, 1),
    x2=(0.0, 1.0),
    y0=(0.0, 1.0),
    y1=(0.0, 1.0),
    y2=(0.0, 1.0),
)
# The last child of the interactive widget is its output area. Giving it a
# fixed height keeps the layout from jumping while the figure is redrawn.
output = interactive_plot.children[-1]
output.layout.height = '350px'
interactive_plot
# Displace the first point one unit to the left and inspect the resulting
# spring forces. xs_ is a copy, so xs0 itself is left untouched.
# Notebook outputs are kept as comments: as bare code they raised a
# NameError because `array` is not defined.
xs_ = np.array(xs0)
xs_[0] -= 1.0
fxs, fys = calculate_spring_forces(xs=xs_, ys=ys0)
fxs
# -> array([1., 0.])
fys
# -> array([0., 0.])

# np.roll wraps values around the ends (see the outputs below), so it cannot
# express "no spring beyond an end point". Zero-padding the spring forces can.
np.roll(fxs, 1)
# -> array([0., 1.])
np.roll(fxs, 2)
# -> array([1., 0.])
fxs_ = np.concatenate([[0], fxs, [0]])
fxs_
# -> array([0., 1., 0., 0.])
fys_ = np.concatenate([[0], fys, [0]])
fys_
# -> array([0., 0., 0., 0.])

# For point i, fx_left[i] is the force of the spring on its left (0 for the
# first point) and fx_right[i] is the force of the spring on its right (0 for
# the last point).
fx_left = fxs_[0:-1]
fx_left
# -> array([0., 1., 0.])
fx_right = fxs_[1:]
fx_right
# -> array([1., 0., 0.])
def spring_force_to_point(fs):
    """Map per-spring values onto the points each spring touches.

    ``fs`` is padded with a zero at both ends. The function returns two arrays
    with ``len(fs) + 1`` entries, one per point:
    - the first holds the value of the spring before each point;
    - the second holds the value of the spring after it.
    An end point gets 0 for its missing neighbour spring.
    """
    padded = np.concatenate(([0], fs, [0]))
    before, after = padded[:-1], padded[1:]
    return before, after

def calculate_point_forces(fxs, fys):
    """Combine per-spring forces into the net force on each point.

    A spring with force f contributes +f to the point at its start and -f to
    the point at its end. Points at the ends of the chain only feel their one
    neighbouring spring. Returns ``(fx, fy)`` with one entry per point.
    """
    before_x, after_x = spring_force_to_point(fxs)
    before_y, after_y = spring_force_to_point(fys)
    return after_x - before_x, after_y - before_y
    
# Net point forces for the displaced configuration: the displaced first point
# is pulled back to the right and the middle point is pulled to the left.
# (Output kept as a comment; as bare code `array` raised a NameError.)
calculate_point_forces(fxs, fys)
# -> (array([ 1., -1.,  0.]), array([0., 0., 0.]))
def f(x0=0.0, y0=0.0, x1=0.5, y1=0.0, x2=1.0, y2=0.0):
    """Draw the three-point chain with an arrow for the net force on each point.

    The forces come from calculate_spring_forces followed by
    calculate_point_forces. Each arrow starts at its point and has the net
    force components as its (dx, dy). The axis limits are fixed, so the view
    does not rescale as the sliders move.
    """
    fig, ax = plt.subplots()
    xs = [x0, x1, x2]
    ys = [y0, y1, y2]

    fxs, fys = calculate_spring_forces(xs=xs, ys=ys)
    fx, fy = calculate_point_forces(fxs=fxs, fys=fys)

    for x, y, fx_, fy_ in zip(xs, ys, fx, fy):
        ax.arrow(x, y, fx_, fy_, head_width=0.05, head_length=0.1)

    ax.plot(xs, ys, 'o-')
    ax.set_xlim(-0.5, 1.5)
    ax.set_ylim(-1, 1)
    # The title must be set before plt.show(). The figure is rendered at
    # show() time (and the inline backend closes it), so changes made
    # afterwards never appear in the displayed figure.
    ax.set_title('Point forces')
    plt.show()

    
# Same sliders as before, now driving the force-arrow version of f.
interactive_plot = interactive(
    f,
    x0=(0.0, 1.0),
    x1=(0.0, 1),
    x2=(0.0, 1.0),
    y0=(0.0, 1.0),
    y1=(0.0, 1.0),
    y2=(0.0, 1.0),
)
# The last child of the interactive widget is its output area. Giving it a
# fixed height keeps the layout from jumping while the figure is redrawn.
output = interactive_plot.children[-1]
output.layout.height = '350px'
interactive_plot
def update(t, states):
    """Right-hand side of the mass-spring ODE, in the form solve_ivp expects.

    ``states`` packs [xs, ys, vxs, vys], each block of length N (assumes
    len(states) == 4 * N). Returns the time derivative [vxs, vys, axs, ays],
    where the accelerations are the net point forces divided by the mass m.
    ``t`` is unused; the forces depend only on the positions.
    """
    pos_x, pos_y = states[0:N], states[N:2*N]
    vel_x, vel_y = states[2*N:3*N], states[3*N:]

    spring_fx, spring_fy = calculate_spring_forces(xs=pos_x, ys=pos_y)
    net_fx, net_fy = calculate_point_forces(fxs=spring_fx, fys=spring_fy)

    return np.concatenate([vel_x, vel_y, net_fx / m, net_fy / m])
    
    
# Zero initial velocities, one per point. They are sized by N rather than
# hard-coded for three points, so changing N keeps the state vector consistent.
dxs = np.zeros(N)
dys = np.zeros(N)
states = np.concatenate([xs0, ys0, dxs, dys])
# At rest with zero velocity, every derivative is zero.
# (Output kept as a comment; as bare code `array` raised a NameError.)
update(0, states=states)
# -> array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
# Axis limits matching those hard-coded in the plotting code.
xlim = [-0.5, 1.5]
ylim = [-1, 1]

# Initial state: the displaced positions xs_, rest heights ys0, and the zero
# velocities dxs / dys.
states_0 = np.concatenate([xs_, ys0, dxs, dys])

# Integrate over t in [0, 20] and store the solution at Ns evenly spaced times.
Ns = 200
t = np.linspace(0, 20, Ns)
result = solve_ivp(fun=update, t_span=[t[0], t[-1]], t_eval=t, y0=states_0)
from matplotlib.animation import FuncAnimation
from matplotlib import animation, rc
from IPython.display import HTML

fig, ax = plt.subplots()
ln, = ax.plot([], [], '-ro')


def init():
    """Fix the axis limits before the first frame; return the artists to blit."""
    ax.set_xlim(-0.5, 1.5)
    ax.set_ylim(-1, 1)
    return ln,


def draw_frame(frame):
    """Move the line to the point positions stored in column ``frame`` of result.y."""
    xs = result.y[0:N, frame]
    ys = result.y[N:2*N, frame]
    ln.set_data(xs, ys)
    return ln,


# The frame callback is named draw_frame, not update, so it does not shadow
# the ODE right-hand side `update` that solve_ivp uses (re-running that cell
# would otherwise pass this frame function to the solver).
# Frames are counted from the columns solve_ivp actually returned: if the
# integration stops early, result.y has fewer than Ns columns.
anim = FuncAnimation(fig, draw_frame, frames=np.arange(result.y.shape[1]),
                     init_func=init, blit=True)
plt.rcParams["animation.html"] = "jshtml"
# Close the figure so the notebook does not also show it as a static plot
# below the animation. to_jshtml() still renders every frame from it.
plt.close(fig)
HTML(anim.to_jshtml())
</input>